import argparse

import numpy as np

import logging
import torch
from wrench.logging import LoggingHandler
from wrench.dataset import load_dataset
from wrench.labelmodel.RACH_Space import RACH_Space_Algorithm
from wrench.endmodel import EndClassifierModel

#### Logging configuration: INFO-level, timestamped messages routed through
#### wrench's LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)

# Module-level logger used for reporting the evaluation results below.
logger = logging.getLogger(__name__)


# CPU torch device. NOTE(review): not referenced by the statements below;
# kept in case other code or future configuration of the label model uses it.
device = torch.device('cpu')

#### Command-line options (the defaults reproduce the original hard-coded run)
parser = argparse.ArgumentParser(
    description='Fit the RACH-Space label model on a WRENCH dataset and '
                'report test-set metrics.'
)
parser.add_argument('--dataset-path', default='../datasets/',
                    help='Directory containing the datasets (default: %(default)s).')
parser.add_argument('--data', default='spouse',
                    help='Name of the dataset to load (default: %(default)s).')
# parse_known_args (rather than parse_args) so unrelated argv entries are
# ignored instead of aborting the run, as the original script ignored argv.
args, _ = parser.parse_known_args()

#### Load dataset
dataset_path = args.dataset_path
data = args.data

train_data, valid_data, test_data = load_dataset(
    dataset_path,
    data,
    extract_feature=False,
)

#### Fit the label model on the training split only
label_model = RACH_Space_Algorithm()
label_model.fit(dataset_train=train_data)

#### Evaluate on the test split
# Lazy %-style arguments: the message is only formatted if the record is
# emitted. The previous messages concatenated two literals with no
# separator and repeated the label.
acc = label_model.test(test_data, 'acc')
logger.info('RACH-Space label model test acc: %s', acc)

f1_binary = label_model.test(test_data, 'f1_binary')
logger.info('RACH-Space label model test f1_binary: %s', f1_binary)


